import numpy as np
import cv2
import numbers
import os

from torchvision import transforms
from PIL import Image
from .tokenizers import Tokenizer

def load_bert_npy(image_path):
    """Load one saved ``.npy`` input and return it unchanged.

    Args:
        image_path: Path of the ``.npy`` file to read.

    Returns:
        The array exactly as ``np.load`` produced it.
    """
    return np.load(image_path)

def load_bertcls_npy(image_path, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], size=(224,224), reverse=False):
    """Load a 2-D token-id array and pad or truncate it to 128 columns.

    Args:
        image_path: Path of a ``.npy`` file holding a 2-D array of token
            ids (rows x sequence length).
        mean, std, size, reverse: Unused; accepted so the function can be
            called with the same arguments as the other pre-processors.

    Returns:
        A freshly allocated ``int64`` array of shape ``(rows, 128)``.
        Sequences shorter than 128 are right-padded with zeros, longer
        ones are cut after the first 128 ids.
    """
    max_len = 128
    token_ids = np.load(image_path)

    # One code path for both cases, so the result always has the same
    # dtype and never aliases the loaded array.
    keep = min(token_ids.shape[1], max_len)
    input_tensor = np.zeros([token_ids.shape[0], max_len], dtype=np.int64)
    input_tensor[:, :keep] = token_ids[:, :keep]

    return input_tensor

def load_bert_sst2(sen_label, tokenizer, max_len=64):
    """Turn one sample into fixed-length token-id and token-type-id arrays.

    Args:
        sen_label: Sequence of strings; every element except the last is
            joined with single spaces to form the sentence (the last
            element is presumably the label -- it is not used here).
        tokenizer: Object providing ``tokenize(text, maxlen=...)`` and
            ``tokens_to_ids(tokens)``.
        max_len: Length of the returned arrays; also forwarded to
            ``tokenizer.tokenize`` as ``maxlen``. Defaults to 64.

    Returns:
        ``(input_token_ids, input_token_type_ids)``: two ``int64`` arrays
        of shape ``(max_len,)``. Token ids are right-padded with zeros or
        truncated to ``max_len``; token type ids are all zeros.
    """
    sentence = " ".join(sen_label[:-1])
    tokens = tokenizer.tokenize(sentence, maxlen=max_len)
    # Force int64 so the truncated and the padded case agree on dtype
    # regardless of the platform's default integer width.
    token_ids = np.asarray(tokenizer.tokens_to_ids(tokens), dtype=np.int64)

    keep = min(token_ids.shape[0], max_len)
    input_token_ids = np.zeros([max_len], dtype=np.int64)
    input_token_ids[:keep] = token_ids[:keep]

    # Single-segment input: every position has type id 0, padded or not.
    input_token_type_ids = np.zeros([max_len], dtype=np.int64)

    return input_token_ids, input_token_type_ids

def load_npy(image_path, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], size=(224,224), reverse=False):
    """Read a ``.npy`` file and hand the array back untouched.

    Args:
        image_path: Path of the ``.npy`` file to read.
        mean, std, size, reverse: Ignored; present so the call signature
            matches the other pre-processing functions.

    Returns:
        The array returned by ``np.load``.
    """
    return np.load(image_path)

def base_preprocess(image_path, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], size=(224,224), reverse=False):
    """Read an image, resize it, normalize it and return it channel-first.

    Args:
        image_path: Path of the image file, read with ``cv2.imread``.
        mean: Scalar or per-channel values subtracted from the pixels.
        std: Scalar or per-channel values the pixels are divided by.
        size: Target size handed to ``cv2.resize`` as ``dsize``
            (OpenCV interprets it as ``(width, height)``).
        reverse: When true, the first and last channel are swapped
            before resizing.

    Returns:
        ``float32`` array laid out as ``(channels, height, width)``.

    Raises:
        OSError: If ``cv2.imread`` could not read ``image_path``.
    """
    image = cv2.imread(image_path)
    if image is None:
        # imread signals failure by returning None instead of raising;
        # fail here with the path rather than deep inside cv2.resize.
        raise OSError("cv2.imread failed to read image: {}".format(image_path))

    if reverse:
        # The 3-channel swap is symmetric, so this flips whatever channel
        # order imread produced.
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    imgcrop = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
    imgcrop = imgcrop.astype(np.float32)
    # In-place ops keep the buffer float32; mean/std broadcast over the
    # trailing channel axis.
    imgcrop -= mean
    imgcrop /= std

    # HWC -> CHW
    imgcrop = imgcrop.transpose([2, 0, 1])
    return imgcrop

def _parse_shapes(spec):
    """Parse ``"a,b#c,d"`` into ``[[a, b], [c, d]]`` (lists of ints)."""
    return [[int(dim) for dim in shape.split(",")] for shape in spec.split("#")]


def get_images(image_list, params):
    """Build the batched network input(s) for a list of samples.

    Keys read from ``params``:
        scale / std: Comma-separated floats. ``scale`` wins when both are
            present and is inverted to obtain ``std``. Default ``1.0``.
        mean: Comma-separated floats. Default ``0.``.
        reverse_channel: Its mere presence enables channel reversal.
        pre_process: Name of the pre-processing function. Defaults to
            ``base_preprocess`` when absent.
        image_T4: Path whose directory is the base for sample file names
            (not needed for ``load_bert_sst2``).
        input_size: Comma-separated dims of one sample; several inputs
            are separated by ``#`` for the BERT pre-processors.
        vocab: Passed to ``Tokenizer`` (``load_bert_sst2`` only).
        post_process: ``get_textcnn_post`` / ``get_bertcls_post`` select
            an ``int64`` batch, anything else (or absent) ``float32``.

    Args:
        image_list: One entry per sample. For ``load_bert_sst2`` each
            entry is passed to the pre-processor as is; otherwise each
            entry is a sequence of file names relative to the directory
            of ``params["image_T4"]``.
        params: Mapping with the keys described above.

    Returns:
        ``load_bert_npy``: list of three ``int32`` arrays.
        ``load_bert_sst2``: list of two ``int64`` arrays.
        Otherwise: a single array of shape ``(batch, *input_size)``.
    """
    # The mean, std value is in the prototxt, if not, default is 0.0/1.0
    if "scale" in params:
        std = tuple(1 / float(value) for value in params["scale"].split(","))
    elif "std" in params:
        std = tuple(float(value) for value in params["std"].split(","))
    else:
        std = 1.0

    if "mean" in params:
        mean = tuple(float(value) for value in params["mean"].split(","))
    else:
        mean = 0.

    reverse = "reverse_channel" in params

    # .get() so that a missing "pre_process" really falls back to the
    # default instead of raising KeyError in the dispatch below.
    pre_name = params.get("pre_process")
    if pre_name is not None:
        # NOTE(security): eval() executes whatever expression the config
        # contains. Only use with trusted configuration files.
        preprocess = eval(pre_name)
    else:
        preprocess = base_preprocess

    batch = len(image_list)

    if pre_name == "load_bert_sst2":
        # Samples are in-memory sentences; no file paths involved.
        tokenizer = Tokenizer(params["vocab"], )
        shapes = _parse_shapes(params["input_size"])
        encoded = [preprocess(sample, tokenizer) for sample in image_list]
        outputs = [np.zeros((batch, *shapes[k]), dtype=np.int64) for k in range(2)]
        for ix, sample in enumerate(encoded):
            for k in range(2):
                outputs[k][ix] = sample[k]
        return outputs

    # os.path.join keeps bare list-file names relative instead of
    # prefixing "/" (which pointed at the filesystem root).
    base_dir = os.path.dirname(params["image_T4"])

    if pre_name == "load_bert_npy":
        shapes = _parse_shapes(params["input_size"])
        loaded = [
            [preprocess(os.path.join(base_dir, name)) for name in names]
            for names in image_list
        ]
        outputs = [np.zeros((batch, *shapes[k]), dtype=np.int32) for k in range(3)]
        for ix, sample in enumerate(loaded):
            for k in range(3):
                outputs[k][ix] = sample[k]
        return outputs

    size = [int(dim) for dim in params["input_size"].split(",")]
    inputs = [
        preprocess(os.path.join(base_dir, names[0]), mean, std, size, reverse)
        for names in image_list
    ]
    # Token-based models need integer inputs; everything else is float.
    int_posts = ("get_textcnn_post", "get_bertcls_post")
    dtype = np.int64 if params.get("post_process") in int_posts else np.float32
    input_ = np.zeros((batch, *size), dtype=dtype)
    for ix, in_ in enumerate(inputs):
        input_[ix] = in_
    return input_